Add diffusion model implementation #408


Draft · wants to merge 59 commits into dev

Conversation

@vpratz (Collaborator) commented Apr 13, 2025

This PR adds a diffusion model implementation for use as an inference network, as discussed in #403. It implements the design introduced as "EDM" in [1]. The overall structure is taken from the FlowMatching class.

@arrjon @niels-leif-bracher I would appreciate it if you could take a look and make suggestions on how we can incorporate the other diffusion model variants as well. For now, I chose to expose only the sigma_data parameter to the end user and keep everything else private. This should allow us to change the internals later on and to incrementally add new functionality.

Please let me know how you would like to proceed and how much capacity you have to move this forward, so that we can decide whether to include the additional options before we merge, or to merge early and add to it incrementally later. I have placed the class in the experimental module for now, so that we keep some freedom to change things in the future as we see fit.

[1] https://arxiv.org/abs/2206.00364
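
For readers following along, a minimal usage sketch of what the exposed surface could look like is shown below; the class name, import path, and constructor signature are assumptions inferred from the file layout in this PR and the description above, not the final API:

```python
import bayesflow as bf

# Hypothetical sketch: the network lives in the experimental module and,
# per the description above, exposes only sigma_data; all other settings
# (schedule, sampler, pre-conditioning) remain private for now.
inference_network = bf.experimental.DiffusionModel(sigma_data=1.0)

# It would then be plugged in wherever a FlowMatching network is used today,
# i.e. as the inference network of an approximator/workflow.
```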

Preliminary implementation, to be extended with other variants as well.
@vpratz vpratz added the feature New feature or request label Apr 13, 2025

codecov bot commented Apr 13, 2025

Codecov Report

Attention: Patch coverage is 60.29777% with 160 lines in your changes missing coverage. Please review.

Files with missing lines                    Patch %    Missing lines
bayesflow/experimental/diffusion_model.py   64.85%     129 ⚠️
bayesflow/utils/integrate.py                 8.82%      31 ⚠️

Files with missing lines                    Coverage Δ
bayesflow/experimental/__init__.py          100.00% <100.00%> (ø)
bayesflow/utils/__init__.py                 100.00% <100.00%> (ø)
bayesflow/utils/integrate.py                 44.06% <8.82%> (-8.71%) ⬇️
bayesflow/experimental/diffusion_model.py    64.85% <64.85%> (ø)

@arrjon arrjon self-assigned this Apr 14, 2025
@arrjon (Collaborator) commented Apr 14, 2025

Thanks @vpratz for the implementation! sigma_data is already specific to the EDM version. I think we have to make it even more general, with some additional arguments that can be passed depending on the type of schedule one wants to use.

I plan to add additional schedules and samplers on top of this by the end of the week.

@vpratz (Collaborator, Author) commented Apr 14, 2025

Thanks for taking a look. Do you know whether your implementation would benefit from the pre-conditioning discussed in Elucidating the Design Space of Diffusion-Based Generative Models, and whether we can combine them in one joint framework?

@arrjon (Collaborator) commented Apr 14, 2025

Part of the pre-conditioning can be expressed as a special kind of weighting function: see Appendix D.1 here.

So yes, the aim would be to have one nice framework!
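
For context, the EDM pre-conditioning from [1] (Table 1) boils down to four sigma-dependent coefficients, and its loss weighting lambda(sigma) = 1 / c_out(sigma)^2 is what the appendix reinterprets as a weighting function. A standalone NumPy sketch of those formulas (an illustration of the paper's parameterization, not the code in this PR):

```python
import numpy as np

def edm_preconditioning(sigma, sigma_data=0.5):
    """EDM pre-conditioning coefficients (Karras et al. 2022, Table 1).

    The denoiser is parameterized as
        D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise),
    where F is the raw neural network.
    """
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / np.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / np.sqrt(sigma**2 + sigma_data**2)
    c_noise = 0.25 * np.log(sigma)
    return c_skip, c_out, c_in, c_noise

# Loss weighting from the paper: lambda(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2,
# i.e. exactly 1 / c_out(sigma)**2, which is the "weighting function" view mentioned above.
```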

@arrjon (Collaborator) commented Apr 16, 2025

I added some more noise schedules and started to make the implementation more general. This is just a first draft, so that you, @vpratz, can get an idea of how we could do it. We should then discuss it and how to move forward.

Base automatically changed from dev to main April 22, 2025 14:37
@arrjon (Collaborator) commented Apr 23, 2025

I added a NoiseSchedule class and different schedules, so it should now be easy to extend to more schedules if necessary. Since EDM has a specific sampling scheme for inference, this is now also defined in the noise schedule. Therefore, we no longer have to specify sampling step sizes by hand.

The next step would be to add stochastic samplers as well.
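
To make the schedule abstraction above concrete for reviewers, here is a rough sketch of what such an interface might look like; the class and method names are illustrative assumptions, not the actual class added in this PR:

```python
from abc import ABC, abstractmethod

class NoiseScheduleSketch(ABC):
    """Illustrative base class: a schedule maps diffusion time t in [0, 1] to a
    log signal-to-noise ratio and owns its inference-time discretization, so
    samplers no longer need schedule-specific step sizes."""

    @abstractmethod
    def log_snr(self, t):
        """Log-SNR lambda(t) of the forward process at time t."""

    @abstractmethod
    def t_from_log_snr(self, log_snr):
        """Inverse mapping, used to place the sampling steps at inference time."""

    def loss_weight(self, log_snr):
        """Optional per-noise-level loss weighting (e.g. EDM-style); uniform by default."""
        return 1.0
```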

@vpratz vpratz changed the base branch from main to dev April 24, 2025 07:17
@vpratz (Collaborator, Author) commented Apr 29, 2025

Thanks a lot for the fixes; they improve the performance of the new implementation considerably. The old standalone EDM implementation still seems to be slightly better, but the difference might come down to hyperparameter tuning. I have added the examples/experimental/Two_Moons_Diffusion_Comparison.ipynb notebook, which allows testing both implementations on the same benchmark and plotting the results against each other. It currently covers two moons, but feel free to expand it with other benchmarks as well.

As far as I can tell, the open steps before we finalize this PR are:

  • optimising performance: ensuring that we achieve the same performance as the standalone EDM implementation, so that we do not lose performance by not including it
  • related: find and set good defaults
  • remove the LinearNoiseSchedule (as discussed privately, maybe moving it to a tutorial): the other schedules perform better, so we do not need to include and maintain it here
  • add tests to cover all relevant cases/combinations
  • proof-reading docstrings and maybe supplying a tutorial
  • remove the examples/experimental/Two_Moons_Diffusion_Comparison.ipynb notebook before merging

Did I miss anything, @arrjon, or do you have any other comments on the current state?

@arrjon (Collaborator) commented Apr 29, 2025

The performance issue is fixed now; it was mainly due to a missing scaling factor on the log_snr that goes into the network.

I am also implementing the stochastic sampler: it is working for all backends but jax at the moment. After this, only the things @vpratz mentioned are missing.

@arrjon (Collaborator) commented Apr 29, 2025

The stochastic sampler is now also working for jax, so all features are done for the moment!
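
For reviewers who want a mental model of the stochastic sampler, the usual building block is an Euler-Maruyama step on the reverse-time SDE; below is a backend-agnostic sketch using Keras 3 ops (a generic illustration, not necessarily how this PR implements it):

```python
import keras

def euler_maruyama_step(x, drift, diffusion, dt, seed=None):
    """One Euler-Maruyama step of a (reverse-time) SDE:
        x_next = x + f(x, t) * dt + g(t) * sqrt(|dt|) * eps,  eps ~ N(0, I),
    where `drift` is the precomputed f(x, t) and `diffusion` the scalar g(t).
    """
    eps = keras.random.normal(keras.ops.shape(x), seed=seed)
    return x + drift * dt + diffusion * keras.ops.sqrt(keras.ops.abs(dt)) * eps
```

Under the jax backend, stateless randomness needs an explicit seed (e.g. a keras.random.SeedGenerator threaded through the sampling loop), which is presumably the extra care jax required here.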

@vpratz (Collaborator, Author) commented Apr 29, 2025

Great! Thanks a lot for putting in the work and for the quick fixes! I'll try to add the relevant tests and work on some of the other missing things in the next few days.
